{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build a strategy from scratch\n",
    "\n",
    "Welcome to the third part of the Flower federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and the Flower framework ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)) and we learned how strategies can be used to customize the execution on both the server and the clients ([part 2](https://flower.ai/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html)).\n",
    "\n",
    "In this notebook, we'll continue to customize the federated learning system we built previously by creating a custom version of FedAvg using the Flower framework, Flower Datasets, and PyTorch.\n",
    "\n",
    "> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Flower Discuss and the Flower Slack to connect, ask questions, and get help:\n",
    "> - [Join Flower Discuss](https://discuss.flower.ai/) We'd love to hear from you in the `Introduction` topic! If anything is unclear, post in `Flower Help - Beginners`.\n",
    "> - [Join Flower Slack](https://flower.ai/join-slack) We'd love to hear from you in the `#introductions` channel! If anything is unclear, head over to the `#questions` channel.\n",
    "\n",
    "Let's build a new `Strategy` from scratch! 🌼"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparation\n",
    "\n",
    "Before we begin with the actual code, let's make sure that we have everything we need."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Installing dependencies\n",
    "\n",
    "First, we install the necessary packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have all dependencies installed, we can import everything we need for this tutorial:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "from typing import Dict, List, Optional, Tuple\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import flwr\n",
    "from flwr.client import Client, ClientApp, NumPyClient\n",
    "from flwr.common import Context\n",
    "from flwr.server import ServerApp, ServerConfig, ServerAppComponents\n",
    "from flwr.server.strategy import Strategy\n",
    "from flwr.simulation import run_simulation\n",
    "from flwr_datasets import FederatedDataset\n",
    "\n",
    "DEVICE = torch.device(\"cpu\")  # Try \"cuda\" to train on GPU\n",
    "print(f\"Training on {DEVICE}\")\n",
    "print(f\"Flower {flwr.__version__} / PyTorch {torch.__version__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: `Runtime > Change runtime type > Hardware accelerator: GPU > Save`). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting `DEVICE = torch.device(\"cpu\")`. If you set `DEVICE = torch.device(\"cuda\")` and the runtime has GPU acceleration enabled, you should see the output `Training on cuda`; otherwise it'll say `Training on cpu`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data loading\n",
    "\n",
    "Let's now load the CIFAR-10 training and test sets. The training set is partitioned into smaller datasets, one per client (ten in this notebook), and each partition is split into a training and a validation set. The test set stays unpartitioned. Every split is then wrapped in its own `DataLoader`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_datasets(partition_id, num_partitions: int):\n",
    "    fds = FederatedDataset(dataset=\"cifar10\", partitioners={\"train\": num_partitions})\n",
    "    partition = fds.load_partition(partition_id)\n",
    "    # Divide data on each node: 80% train, 20% held out for validation\n",
    "    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)\n",
    "    pytorch_transforms = transforms.Compose(\n",
    "        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n",
    "    )\n",
    "\n",
    "    def apply_transforms(batch):\n",
    "        # Instead of passing transforms to CIFAR10(..., transform=transform),\n",
    "        # we apply them via dataset.with_transform(apply_transforms).\n",
    "        # The transforms object is exactly the same.\n",
    "        batch[\"img\"] = [pytorch_transforms(img) for img in batch[\"img\"]]\n",
    "        return batch\n",
    "\n",
    "    partition_train_test = partition_train_test.with_transform(apply_transforms)\n",
    "    trainloader = DataLoader(partition_train_test[\"train\"], batch_size=32, shuffle=True)\n",
    "    valloader = DataLoader(partition_train_test[\"test\"], batch_size=32)\n",
    "    testset = fds.load_split(\"test\").with_transform(apply_transforms)\n",
    "    testloader = DataLoader(testset, batch_size=32)\n",
    "    return trainloader, valloader, testloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model training/evaluation\n",
    "\n",
    "Let's continue with the usual model definition (including `set_parameters` and `get_parameters`), training and test functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self) -> None:\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "def get_parameters(net) -> List[np.ndarray]:\n",
    "    return [val.cpu().numpy() for _, val in net.state_dict().items()]\n",
    "\n",
    "\n",
    "def set_parameters(net, parameters: List[np.ndarray]):\n",
    "    params_dict = zip(net.state_dict().keys(), parameters)\n",
    "    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})\n",
    "    net.load_state_dict(state_dict, strict=True)\n",
    "\n",
    "\n",
    "def train(net, trainloader, epochs: int):\n",
    "    \"\"\"Train the network on the training set.\"\"\"\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.Adam(net.parameters())\n",
    "    net.train()\n",
    "    for epoch in range(epochs):\n",
    "        correct, total, epoch_loss = 0, 0, 0.0\n",
    "        for batch in trainloader:\n",
    "            images, labels = batch[\"img\"], batch[\"label\"]\n",
    "            images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = net(images)\n",
    "            # Reuse the forward pass instead of calling net(images) a second time\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Metrics\n",
    "            epoch_loss += loss.item()  # accumulate a float, not a graph-attached tensor\n",
    "            total += labels.size(0)\n",
    "            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()\n",
    "        epoch_loss /= len(trainloader.dataset)\n",
    "        epoch_acc = correct / total\n",
    "        print(f\"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}\")\n",
    "\n",
    "\n",
    "def test(net, testloader):\n",
    "    \"\"\"Evaluate the network on the entire test set.\"\"\"\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    correct, total, loss = 0, 0, 0.0\n",
    "    net.eval()\n",
    "    with torch.no_grad():\n",
    "        for batch in testloader:\n",
    "            images, labels = batch[\"img\"], batch[\"label\"]\n",
    "            images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
    "            outputs = net(images)\n",
    "            loss += criterion(outputs, labels).item()\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "    loss /= len(testloader.dataset)\n",
    "    accuracy = correct / total\n",
    "    return loss, accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Flower client\n",
    "\n",
    "To implement the Flower client, we (again) create a subclass of `flwr.client.NumPyClient` and implement the three methods `get_parameters`, `fit`, and `evaluate`. Here, we also pass the `partition_id` to the client and use it to log additional details. We then create an instance of `ClientApp` and pass it the `client_fn`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FlowerClient(NumPyClient):\n",
    "    def __init__(self, partition_id, net, trainloader, valloader):\n",
    "        self.partition_id = partition_id\n",
    "        self.net = net\n",
    "        self.trainloader = trainloader\n",
    "        self.valloader = valloader\n",
    "\n",
    "    def get_parameters(self, config):\n",
    "        print(f\"[Client {self.partition_id}] get_parameters\")\n",
    "        return get_parameters(self.net)\n",
    "\n",
    "    def fit(self, parameters, config):\n",
    "        print(f\"[Client {self.partition_id}] fit, config: {config}\")\n",
    "        set_parameters(self.net, parameters)\n",
    "        train(self.net, self.trainloader, epochs=1)\n",
    "        # Report the number of training examples (not batches) as the aggregation weight\n",
    "        return get_parameters(self.net), len(self.trainloader.dataset), {}\n",
    "\n",
    "    def evaluate(self, parameters, config):\n",
    "        print(f\"[Client {self.partition_id}] evaluate, config: {config}\")\n",
    "        set_parameters(self.net, parameters)\n",
    "        loss, accuracy = test(self.net, self.valloader)\n",
    "        # Report the number of validation examples (not batches) as the aggregation weight\n",
    "        return float(loss), len(self.valloader.dataset), {\"accuracy\": float(accuracy)}\n",
    "\n",
    "\n",
    "def client_fn(context: Context) -> Client:\n",
    "    net = Net().to(DEVICE)\n",
    "    partition_id = context.node_config[\"partition-id\"]\n",
    "    num_partitions = context.node_config[\"num-partitions\"]\n",
    "    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)\n",
    "    return FlowerClient(partition_id, net, trainloader, valloader).to_client()\n",
    "\n",
    "\n",
    "# Create the ClientApp\n",
    "client = ClientApp(client_fn=client_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test what we have so far before we continue:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_PARTITIONS = 10\n",
    "\n",
    "\n",
    "def server_fn(context: Context) -> ServerAppComponents:\n",
    "    # Configure the server for just 3 rounds of training\n",
    "    config = ServerConfig(num_rounds=3)\n",
    "    # If no strategy is provided, by default, ServerAppComponents will use FedAvg\n",
    "    return ServerAppComponents(config=config)\n",
    "\n",
    "\n",
    "# Create the ServerApp\n",
    "server = ServerApp(server_fn=server_fn)\n",
    "\n",
    "# Specify the resources each of your clients needs\n",
    "# If set to None, each client is allocated 2 CPUs and 0 GPUs by default\n",
    "backend_config = {\"client_resources\": None}\n",
    "if DEVICE.type == \"cuda\":\n",
    "    backend_config = {\"client_resources\": {\"num_gpus\": 1}}\n",
    "\n",
    "# Run simulation\n",
    "run_simulation(\n",
    "    server_app=server,\n",
    "    client_app=client,\n",
    "    num_supernodes=NUM_PARTITIONS,\n",
    "    backend_config=backend_config,\n",
    ")"
   ]
  },
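  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build a Strategy from scratch\n",
    "\n",
    "Let’s override the `configure_fit` method so that it sends a higher learning rate (and potentially other hyperparameters) to a fraction of the clients. We keep the client sampling as it is in `FedAvg` and only change the configuration dictionary (one of the `FitIns` attributes): the first half of the sampled clients receives `lr = 0.001` and the rest receives `lr = 0.003`. Note that `FlowerClient.fit` above only prints the received `config`; for the learning rate to take effect, the client would have to pass the `lr` entry on to its optimizer."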
   ]
  },
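  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On the client side, whatever the strategy puts into the configuration dictionary arrives as the `config` argument of `fit`. A minimal sketch of reading a learning rate from it, assuming the key is named `lr` (as in the strategy below); the helper `lr_from_config` and the alias `ConfigValue` are hypothetical and not part of Flower:\n",
    "\n",
    "```python\n",
    "from typing import Dict, Union\n",
    "\n",
    "# Value types assumed for config entries (an illustrative alias)\n",
    "ConfigValue = Union[bool, bytes, float, int, str]\n",
    "\n",
    "\n",
    "def lr_from_config(config: Dict[str, ConfigValue], default: float = 0.001) -> float:\n",
    "    \"\"\"Return the learning rate sent by the strategy, or `default` if absent.\"\"\"\n",
    "    return float(config.get(\"lr\", default))\n",
    "\n",
    "\n",
    "# In `fit`, the result could be handed to the optimizer, for example:\n",
    "#   torch.optim.Adam(net.parameters(), lr=lr_from_config(config))\n",
    "print(lr_from_config({\"lr\": 0.003}))  # 0.003\n",
    "print(lr_from_config({}))  # 0.001\n",
    "```\n",
    "\n",
    "With a helper like this, `train` could accept the learning rate as an extra argument instead of relying on the optimizer's default."
   ]
  },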
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Union\n",
    "\n",
    "from flwr.common import (\n",
    "    EvaluateIns,\n",
    "    EvaluateRes,\n",
    "    FitIns,\n",
    "    FitRes,\n",
    "    Parameters,\n",
    "    Scalar,\n",
    "    ndarrays_to_parameters,\n",
    "    parameters_to_ndarrays,\n",
    ")\n",
    "from flwr.server.client_manager import ClientManager\n",
    "from flwr.server.client_proxy import ClientProxy\n",
    "from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg\n",
    "\n",
    "\n",
    "class FedCustom(Strategy):\n",
    "    def __init__(\n",
    "        self,\n",
    "        fraction_fit: float = 1.0,\n",
    "        fraction_evaluate: float = 1.0,\n",
    "        min_fit_clients: int = 2,\n",
    "        min_evaluate_clients: int = 2,\n",
    "        min_available_clients: int = 2,\n",
    "    ) -> None:\n",
    "        super().__init__()\n",
    "        self.fraction_fit = fraction_fit\n",
    "        self.fraction_evaluate = fraction_evaluate\n",
    "        self.min_fit_clients = min_fit_clients\n",
    "        self.min_evaluate_clients = min_evaluate_clients\n",
    "        self.min_available_clients = min_available_clients\n",
    "\n",
    "    def __repr__(self) -> str:\n",
    "        return \"FedCustom\"\n",
    "\n",
    "    def initialize_parameters(\n",
    "        self, client_manager: ClientManager\n",
    "    ) -> Optional[Parameters]:\n",
    "        \"\"\"Initialize global model parameters.\"\"\"\n",
    "        net = Net()\n",
    "        ndarrays = get_parameters(net)\n",
    "        return ndarrays_to_parameters(ndarrays)\n",
    "\n",
    "    def configure_fit(\n",
    "        self, server_round: int, parameters: Parameters, client_manager: ClientManager\n",
    "    ) -> List[Tuple[ClientProxy, FitIns]]:\n",
    "        \"\"\"Configure the next round of training.\"\"\"\n",
    "\n",
    "        # Sample clients\n",
    "        sample_size, min_num_clients = self.num_fit_clients(\n",
    "            client_manager.num_available()\n",
    "        )\n",
    "        clients = client_manager.sample(\n",
    "            num_clients=sample_size, min_num_clients=min_num_clients\n",
    "        )\n",
    "\n",
    "        # Create custom configs\n",
    "        n_clients = len(clients)\n",
    "        half_clients = n_clients // 2\n",
    "        standard_config = {\"lr\": 0.001}\n",
    "        higher_lr_config = {\"lr\": 0.003}\n",
    "        fit_configurations = []\n",
    "        for idx, client in enumerate(clients):\n",
    "            if idx < half_clients:\n",
    "                fit_configurations.append((client, FitIns(parameters, standard_config)))\n",
    "            else:\n",
    "                fit_configurations.append(\n",
    "                    (client, FitIns(parameters, higher_lr_config))\n",
    "                )\n",
    "        return fit_configurations\n",
    "\n",
    "    def aggregate_fit(\n",
    "        self,\n",
    "        server_round: int,\n",
    "        results: List[Tuple[ClientProxy, FitRes]],\n",
    "        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],\n",
    "    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:\n",
    "        \"\"\"Aggregate fit results using weighted average.\"\"\"\n",
    "\n",
    "        weights_results = [\n",
    "            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)\n",
    "            for _, fit_res in results\n",
    "        ]\n",
    "        parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))\n",
    "        metrics_aggregated = {}\n",
    "        return parameters_aggregated, metrics_aggregated\n",
    "\n",
    "    def configure_evaluate(\n",
    "        self, server_round: int, parameters: Parameters, client_manager: ClientManager\n",
    "    ) -> List[Tuple[ClientProxy, EvaluateIns]]:\n",
    "        \"\"\"Configure the next round of evaluation.\"\"\"\n",
    "        if self.fraction_evaluate == 0.0:\n",
    "            return []\n",
    "        config = {}\n",
    "        evaluate_ins = EvaluateIns(parameters, config)\n",
    "\n",
    "        # Sample clients\n",
    "        sample_size, min_num_clients = self.num_evaluation_clients(\n",
    "            client_manager.num_available()\n",
    "        )\n",
    "        clients = client_manager.sample(\n",
    "            num_clients=sample_size, min_num_clients=min_num_clients\n",
    "        )\n",
    "\n",
    "        # Return client/config pairs\n",
    "        return [(client, evaluate_ins) for client in clients]\n",
    "\n",
    "    def aggregate_evaluate(\n",
    "        self,\n",
    "        server_round: int,\n",
    "        results: List[Tuple[ClientProxy, EvaluateRes]],\n",
    "        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],\n",
    "    ) -> Tuple[Optional[float], Dict[str, Scalar]]:\n",
    "        \"\"\"Aggregate evaluation losses using weighted average.\"\"\"\n",
    "\n",
    "        if not results:\n",
    "            return None, {}\n",
    "\n",
    "        loss_aggregated = weighted_loss_avg(\n",
    "            [\n",
    "                (evaluate_res.num_examples, evaluate_res.loss)\n",
    "                for _, evaluate_res in results\n",
    "            ]\n",
    "        )\n",
    "        metrics_aggregated = {}\n",
    "        return loss_aggregated, metrics_aggregated\n",
    "\n",
    "    def evaluate(\n",
    "        self, server_round: int, parameters: Parameters\n",
    "    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:\n",
    "        \"\"\"Evaluate global model parameters using an evaluation function.\"\"\"\n",
    "\n",
    "        # Let's assume we won't perform the global model evaluation on the server side.\n",
    "        return None\n",
    "\n",
    "    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:\n",
    "        \"\"\"Return sample size and required number of clients.\"\"\"\n",
    "        num_clients = int(num_available_clients * self.fraction_fit)\n",
    "        return max(num_clients, self.min_fit_clients), self.min_available_clients\n",
    "\n",
    "    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:\n",
    "        \"\"\"Use a fraction of available clients for evaluation.\"\"\"\n",
    "        num_clients = int(num_available_clients * self.fraction_evaluate)\n",
    "        return max(num_clients, self.min_evaluate_clients), self.min_available_clients"
   ]
  },
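  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`aggregate_fit` delegates to Flower's `aggregate` helper, which averages the client parameters weighted by `num_examples`. The idea can be sketched in plain Python. `weighted_average` below is a hypothetical stand-in for illustration, not the library implementation, and it uses flat lists of floats instead of NumPy arrays:\n",
    "\n",
    "```python\n",
    "from typing import List, Tuple\n",
    "\n",
    "\n",
    "def weighted_average(results: List[Tuple[List[float], int]]) -> List[float]:\n",
    "    \"\"\"Average parameter vectors, weighting each by its number of examples.\"\"\"\n",
    "    total_examples = sum(num_examples for _, num_examples in results)\n",
    "    num_params = len(results[0][0])\n",
    "    averaged = [0.0] * num_params\n",
    "    for params, num_examples in results:\n",
    "        for i, value in enumerate(params):\n",
    "            averaged[i] += value * num_examples / total_examples\n",
    "    return averaged\n",
    "\n",
    "\n",
    "# Two clients: the second holds three times as much data as the first\n",
    "print(weighted_average([([0.0, 4.0], 1), ([4.0, 0.0], 3)]))  # [3.0, 1.0]\n",
    "```\n",
    "\n",
    "The client with more examples pulls the average towards its own parameters, which is why the value a client reports as `num_examples` matters."
   ]
  },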
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The only thing left is to use the newly created custom Strategy `FedCustom` when starting the experiment:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def server_fn(context: Context) -> ServerAppComponents:\n",
    "    # Configure the server for just 3 rounds of training\n",
    "    config = ServerConfig(num_rounds=3)\n",
    "    return ServerAppComponents(\n",
    "        config=config,\n",
    "        strategy=FedCustom(),  # <-- pass the new strategy here\n",
    "    )\n",
    "\n",
    "\n",
    "# Recreate the ServerApp so that it uses the new server_fn\n",
    "# (the earlier `server` object still points to the FedAvg version)\n",
    "server = ServerApp(server_fn=server_fn)\n",
    "\n",
    "# Run simulation\n",
    "run_simulation(\n",
    "    server_app=server,\n",
    "    client_app=client,\n",
    "    num_supernodes=NUM_PARTITIONS,\n",
    "    backend_config=backend_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recap\n",
    "\n",
    "In this notebook, we’ve seen how to implement a custom strategy. A custom strategy enables granular control over client node configuration, result aggregation, and more. To define a custom strategy, you only have to override the abstract methods of the (abstract) base class `Strategy`. To make custom strategies even more powerful, you can pass custom functions to the constructor of your new class (`__init__`) and then call these functions whenever needed."
   ]
  },
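  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Passing a function to the constructor can be sketched without Flower: the constructor stores a user-supplied callable, and the strategy calls it when it builds the per-round config. The class `ConfigurableStrategy`, the parameter name `on_fit_config_fn`, and the function `decaying_lr` are illustrative assumptions here, not the `FedCustom` class from this notebook:\n",
    "\n",
    "```python\n",
    "from typing import Callable, Dict, Optional\n",
    "\n",
    "\n",
    "class ConfigurableStrategy:\n",
    "    \"\"\"Toy stand-in for a strategy that accepts a config callback.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, on_fit_config_fn: Optional[Callable[[int], Dict[str, float]]] = None\n",
    "    ) -> None:\n",
    "        self.on_fit_config_fn = on_fit_config_fn\n",
    "\n",
    "    def fit_config(self, server_round: int) -> Dict[str, float]:\n",
    "        # Fall back to an empty config when no callback was supplied\n",
    "        if self.on_fit_config_fn is None:\n",
    "            return {}\n",
    "        return self.on_fit_config_fn(server_round)\n",
    "\n",
    "\n",
    "def decaying_lr(server_round: int) -> Dict[str, float]:\n",
    "    \"\"\"Halve the learning rate every round, starting from 0.004.\"\"\"\n",
    "    return {\"lr\": 0.004 / (2**server_round)}\n",
    "\n",
    "\n",
    "strategy = ConfigurableStrategy(on_fit_config_fn=decaying_lr)\n",
    "print(strategy.fit_config(1))  # {'lr': 0.002}\n",
    "```\n",
    "\n",
    "In a real strategy, `configure_fit` would call the stored function with `server_round` and place the result in each `FitIns`."
   ]
  },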
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next steps\n",
    "\n",
    "Before you continue, make sure to join the Flower community on Flower Discuss ([Join Flower Discuss](https://discuss.flower.ai)) and on Slack ([Join Slack](https://flower.ai/join-slack/)).\n",
    "\n",
    "There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n",
    "\n",
    "The [Flower Federated Learning Tutorial - Part 4](https://flower.ai/docs/framework/tutorial-customize-the-client-pytorch.html) introduces `Client`, the flexible API underlying `NumPyClient`."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Flower-3-Building-a-Strategy-PyTorch.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3.7.12 64-bit ('flower-3.7.12')",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
